{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "快速指导\n",
    "======\n",
    "启动和运行Transformers！无论你是开发者还是日常用户，这个快速指导将帮助你入门，并告诉你如何使用pipeline()用于推理，用AutoClass加载一个预训练的模型和预处理器，用PyTorch或TensorFlow快速训练一个模型。如果你是一个初学者，我们建议查看我们的教程或下一个课程，以便对这里介绍的概念进行更深入的了解。\n",
    "\n",
    "在你开始之前，确保你已经安装了所有必要的库："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "!pip install transformers datasets"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "你还需要安装你喜欢的机器学习框架："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "pip install torch"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Pipeline\n",
    "\n",
    "pipeline()是使用预训练模型进行推理的最简单和最快的方法。你可以将pipeline()开箱即用，用于不同模式的许多任务，其中一些任务如下表所示：\n",
    "\n",
    "| **Task**                     | **Description**                                                                                              | **Modality**    | **Pipeline identifier**                       |\n",
    "|------------------------------|--------------------------------------------------------------------------------------------------------------|-----------------|-----------------------------------------------|\n",
    "| Text classification          | assign a label to a given sequence of text                                                                   | NLP             | pipeline(task=“sentiment-analysis”)           |\n",
    "| Text generation              | generate text given a prompt                                                                                 | NLP             | pipeline(task=“text-generation”)              |\n",
    "| Summarization                | generate a summary of a sequence of text or document                                                         | NLP             | pipeline(task=“summarization”)                |\n",
    "| Image classification         | assign a label to an image                                                                                   | Computer vision | pipeline(task=“image-classification”)         |\n",
    "| Image segmentation           | assign a label to each individual pixel of an image (supports semantic, panoptic, and instance segmentation) | Computer vision | pipeline(task=“image-segmentation”)           |\n",
    "| Object detection             | predict the bounding boxes and classes of objects in an image                                                | Computer vision | pipeline(task=“object-detection”)             |\n",
    "| Audio classification         | assign a label to some audio data                                                                            | Audio           | pipeline(task=“audio-classification”)         |\n",
    "| Automatic speech recognition | transcribe speech into text                                                                                  | Audio           | pipeline(task=“automatic-speech-recognition”) |\n",
    "| Visual question answering    | answer a question about the image, given an image and a question                                             | Multimodal      | pipeline(task=“vqa”)                          |\n",
    "| Document question answering  | answer a question about a document, given an image and a question                                            | Multimodal      | pipeline(task=\"document-question-answering\")  |\n",
    "| Image captioning             | generate a caption for a given image                                                                         | Multimodal      | pipeline(task=\"image-to-text\")                |\n",
    "\n",
    "首先创建一个pipeline()的实例，并指定一个你想使用它的任务。在本指南中，你将以情感分析的pipeline()为例："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "from transformers import pipeline\n",
    "classifier = pipeline(\"sentiment-analysis\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "pipeline()下载并缓存了一个默认的预训练模型和标记器，用于情感分析。现在你可以在你的目标文本上使用该分类器："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "classifier(\"We are very happy to show you the 🤗 Transformers library.\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[{'label': 'POSITIVE', 'score': 0.9998}]"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "如果你有一个以上的输入，把你的输入作为一个列表传给pipeline()，以返回一个字典的列表："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "results = classifier([\"We are very happy to show you the 🤗 Transformers library.\", \"We hope you don't hate it.\"])\n",
    "    for result in results:\n",
    "        print(f\"label: {result['label']}, with score: {round(result['score'], 4)}\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "label: POSITIVE, with score: 0.9998\n",
    "\n",
    "label: NEGATIVE, with score: 0.5309\n",
    "\n",
    "pipeline()也可以为你喜欢的任何任务在整个数据集上迭代。在这个例子中，让我们选择自动语音识别作为我们的任务："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from transformers import pipeline\n",
    "\n",
    "speech_recognizer = pipeline(\"automatic-speech-recognition\", model=\"facebook/wav2vec2-base-960h\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "加载一个你想迭代的音频数据集（详见《数据集快速入门》）。例如，加载MInDS-14数据集："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "from datasets import load_dataset, Audio\n",
    "\n",
    "dataset = load_dataset(\"PolyAI/minds14\", name=\"en-US\", split=\"train\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "你需要确保数据集的采样率与facebook/wav2vec2-base-960h训练的采样率一致："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "dataset = dataset.cast_column(\"audio\", Audio(sampling_rate=speech_recognizer.feature_extractor.sampling_rate))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "调用 \"音频 \"列时，音频文件被自动加载并重新采样。从前4个样本中提取原始波形阵列，并将其作为一个列表传递给管道："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "result = speech_recognizer(dataset[:4][\"audio\"])\n",
    "print([d[\"text\"] for d in result])"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "['I WOULD LIKE TO SET UP A JOINT ACCOUNT WITH MY PARTNER HOW DO I PROCEED WITH DOING THAT', \"FONDERING HOW I'D SET UP A JOIN TO HELL T WITH MY WIFE AND WHERE THE AP MIGHT BE\", \"I I'D LIKE TOY SET UP A JOINT ACCOUNT WITH MY PARTNER I'M NOT SEEING THE OPTION TO DO IT ON THE APSO I CALLED IN TO GET SOME HELP CAN I JUST DO IT OVER THE PHONE WITH YOU AND GIVE YOU THE INFORMATION OR SHOULD I DO IT IN THE AP AN I'M MISSING SOMETHING UQUETTE HAD PREFERRED TO JUST DO IT OVER THE PHONE OF POSSIBLE THINGS\", 'HOW DO I FURN A JOINA COUT']"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "对于输入量较大的数据集（如语音或视觉），你要传递一个生成器而不是一个列表，以便在内存中加载所有的输入。更多信息请看 pipeline API reference。"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 在pipeline中使用另一个模型和标记器\n",
    "pipeline可以容纳Hub中的任何模型，这使得pipeline很容易适应其他情况。例如，如果你想要一个能够处理法语文本的模型，请使用枢纽上的标签来筛选合适的模型。最上面的过滤结果返回了一个多语言的BERT模型，该模型为情感分析进行了微调，你可以用于法语文本："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "model_name = \"nlptown/bert-base-multilingual-uncased-sentiment\""
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "使用AutoModelForSequenceClassification和AutoTokenizer来加载预训练的模型和它相关的标记器（在下一节有更多关于AutoClass的内容）："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer, AutoModelForSequenceClassification\n",
    "\n",
    "model = AutoModelForSequenceClassification.from_pretrained(model_name)\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在pipeline()中指定模型和标记器，现在你可以对法语文本应用分类器："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "classifier = pipeline(\"sentiment-analysis\", model=model, tokenizer=tokenizer)\n",
    "classifier(\"Nous sommes très heureux de vous présenter la bibliothèque 🤗 Transformers.\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[{'label': '5 stars', 'score': 0.7273}]\n",
    "\n",
    "如果你找不到适合你的用例的模型，你需要在你的数据上微调一个预训练的模型。看一下我们的微调教程，了解如何进行微调。最后，在你对你的预训练模型进行微调后，请考虑在Hub上与社区分享该模型！"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## AutoClass\n",
    "\n",
    "在整个构架下，AutoModelForSequenceClassification和AutoTokenizer类一起工作，为之前提到的pipeline()提供动力。AutoClass是一个快捷方式，它可以从一个预训练模型的名称或路径中自动检索出该模型的架构。你只需要为你的任务和它的相关预处理类选择合适的自动类。\n",
    "\n",
    "让我们回到上一节的例子，看看如何使用自动类来重现pipeline()的结果。\n",
    "\n",
    "### AutoTokenizer\n",
    "\n",
    "AutoTokenizer负责将文本预处理成一个数字阵列，作为模型的输入。有多种规则制约着标记化过程，包括如何拆分一个词以及应该在哪个层次上拆分一个词（在标记化器摘要中了解更多关于标记化的信息）。最重要的一点是，你需要用相同的模型名称来实例化一个标记化器，以确保你使用的是与模型预训练时相同的标记化规则。\n",
    "\n",
    "用AutoTokenizer加载一个标记器："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "model_name = \"nlptown/bert-base-multilingual-uncased-sentiment\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "将你的文本传递给标记器："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "encoding = tokenizer(\"We are very happy to show you the 🤗 Transformers library.\")\n",
    "print(encoding)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "{'input_ids': [101, 11312, 10320, 12495, 19308, 10114, 11391, 10855, 10103, 100, 58263, 13299, 119, 102],\n",
    " 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
    " 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}\n",
    "\n",
    " 该标记器的返回包含以下字典：\n",
    " >input_ids：代币的数字编码。\n",
    "\n",
    " >attention_mask：表示哪些标记应该被关注。\n",
    "\n",
    "标记器也可以接受一个输入批次列表，并对文本进行填充和截断，以返回一个具有统一长度的数字阵列："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "pt_batch = tokenizer(\n",
    "    [\"We are very happy to show you the 🤗 Transformers library.\", \"We hope you don't hate it.\"],\n",
    "    padding=True,\n",
    "    truncation=True,\n",
    "    max_length=512,\n",
    "    return_tensors=\"pt\",\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    ">请查看预处理教程，了解更多关于标记化的细节，以及如何使用AutoImageProcessor、AutoFeatureExtractor和AutoProcessor来预处理图像、音频和多模态输入。\n",
    "\n",
    "### AutoModel\n",
    "\n",
    "Transformers提供了一个简单而统一的方式来加载预训练的实例。这意味着你可以像加载AutoTokenizer一样加载一个AutoModel。唯一的区别是为该任务选择正确的AutoModel。对于文本（或序列）分类，你应该加载AutoModelForSequenceClassification："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "from transformers import AutoModelForSequenceClassification\n",
    "\n",
    "model_name = \"nlptown/bert-base-multilingual-uncased-sentiment\"\n",
    "pt_model = AutoModelForSequenceClassification.from_pretrained(model_name)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    ">关于AutoModel类所支持的任务，请参见任务摘要。\n",
    "\n",
    "现在把你预处理过的数字阵列输入直接传递给模型。你只需要在字典中加入**来解包："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "pt_outputs = pt_model(**pt_batch)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "该模型在logits属性中输出最终的张量。对logits应用softmax函数来检索概率："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "from torch import nn\n",
    "\n",
    "pt_predictions = nn.functional.softmax(pt_outputs.logits, dim=-1)\n",
    "print(pt_predictions)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "tensor([[0.0021, 0.0018, 0.0115, 0.2121, 0.7725],\n",
    "        [0.2084, 0.1826, 0.1969, 0.1755, 0.2365]], grad_fn=<SoftmaxBackward0>)\n",
    "\n",
    ">所有Transformers模型（PyTorch或TensorFlow）在最终激活函数（如softmax）之前输出张量，因为最终激活函数通常与损失融合。模型输出是特殊的数据类，所以它们的属性在IDE中是自动完成的。模型输出的行为就像一个元组或一个字典（你可以用一个整数、一个片断或一个字符串作为索引），在这种情况下，属性为None的会被忽略。\n",
    "\n",
    "### 保存一个模型\n",
    "\n",
    "一旦你的模型被微调，你可以使用PreTrainedModel.save_pretrained()将它与它的标记器一起保存："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "pt_save_directory = \"./pt_save_pretrained\"\n",
    "tokenizer.save_pretrained(pt_save_directory)\n",
    "pt_model.save_pretrained(pt_save_directory)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "当你准备再次使用该模型时，用PreTrainedModel.from_pretrained()重新加载它："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "pt_model = AutoModelForSequenceClassification.from_pretrained(\"./pt_save_pretrained\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "一个特别酷的Transformers功能是能够保存一个模型并将其重新加载为PyTorch或TensorFlow模型。from_pt或from_tf参数可以将模型从一个框架转换到另一个框架："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "from transformers import AutoModel\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(tf_save_directory)\n",
    "pt_model = AutoModelForSequenceClassification.from_pretrained(tf_save_directory, from_tf=True)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 模型定制\n",
    "\n",
    "你可以修改模型的配置类来改变模型的构建方式。配置指定了一个模型的属性，如隐藏层的数量或注意头。当你从一个自定义配置类中初始化一个模型时，你将从头开始训练模型。模型的属性是随机初始化的，你需要在通过它获得有意义的结果之前训练该模型。\n",
    "\n",
    "首先导入AutoConfig，然后加载你想修改的预训练模型。在AutoConfig.from_pretrained()中，你可以指定你想改变的属性，比如说注意头的数量："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "from transformers import AutoConfig\n",
    "\n",
    "my_config = AutoConfig.from_pretrained(\"distilbert-base-uncased\", n_heads=12)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "用AutoModel.from_config()从你的自定义配置中创建一个模型："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "from transformers import AutoModel\n",
    "\n",
    "my_model = AutoModel.from_config(my_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "看一下创建自定义架构指南，了解更多关于构建自定义配置的信息。"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 训练器 - 一个PyTorch优化的训练循环\n",
    "\n",
    "所有的模型都是一个标准的torch.nn.Module，所以你可以在任何典型的训练循环中使用它们。虽然你可以编写你自己的训练循环，但是Transformers为PyTorch提供了一个训练器类，它包含基本的训练循环，并为分布式训练、混合精度等特性增加了额外的功能。\n",
    "\n",
    "根据你的任务，你通常会向Trainer传递以下参数：\n",
    "\n",
    "1.一个 PreTrainedModel 或者 一个 torch.nn.Module:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "from transformers import AutoModelForSequenceClassification\n",
    "\n",
    "model = AutoModelForSequenceClassification.from_pretrained(\"distilbert-base-uncased\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "2.训练参数（TrainingArguments）包含了你可以改变的模型超参数，如学习率、批次大小和训练的历时数。如果你没有指定任何训练参数，则使用默认值："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "from transformers import TrainingArguments\n",
    "\n",
    "training_args = TrainingArguments(\n",
    "    output_dir=\"path/to/save/folder/\",\n",
    "    learning_rate=2e-5,\n",
    "    per_device_train_batch_size=8,\n",
    "    per_device_eval_batch_size=8,\n",
    "    num_train_epochs=2,\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "3.一个预处理类，如标记器、图像处理器、特征提取器或处理器："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"distilbert-base-uncased\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "4.加载一个数据集："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "dataset = load_dataset(\"rotten_tomatoes\")  # doctest: +IGNORE_RESULT"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "5.创建一个函数，对数据集进行标记："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "def tokenize_dataset(dataset):\n",
    "    return tokenizer(dataset[\"text\"])"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "然后用map函数在整个数据集上应用它："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "dataset = dataset.map(tokenize_dataset, batched=True)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "6.一个DataCollatorWithPadding，用于从你的数据集中创建一批例子："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "from transformers import DataCollatorWithPadding\n",
    "\n",
    "data_collator = DataCollatorWithPadding(tokenizer=tokenizer)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "现在，将所有这些类作为参数传递给Trainer："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "from transformers import Trainer\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=dataset[\"train\"],\n",
    "    eval_dataset=dataset[\"test\"],\n",
    "    tokenizer=tokenizer,\n",
    "    data_collator=data_collator,\n",
    ")  # doctest: +SKIP"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "当你准备好后，调用train()开始训练："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    ">对于使用sequence-to-sequence模型的任务--如翻译或总结--请使用Seq2SeqTrainer和Seq2SeqTrainingArguments类代替。\n",
    "\n",
    "你可以通过对Trainer内部的方法进行子类化来定制训练循环的行为。这使得你可以定制一些功能，如损失函数、优化器和调度器。看看Trainer的参考资料，看看哪些方法可以被子类化。\n",
    "\n",
    "另一种定制训练循环的方法是使用回调函数（Callbacks）。你可以使用回调来与其他库集成，并检查训练循环以报告进度或提前停止训练。回调不会修改训练循环本身的任何内容。要定制像损失函数这样的东西，你需要对训练器进行子类化。"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 下一步做什么？\n",
    "\n",
    "现在你已经完成了Transformers的快速浏览，看看我们的指南，学习如何做更具体的事情，如编写一个自定义模型，为一个任务微调模型，以及如何用脚本训练模型。如果你有兴趣了解更多关于Transformers的核心概念，请拿起一杯咖啡，看看我们的概念指南!"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
